// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3
#pragma once

/** @file linear.h
	Linear algebra functions.
	*/

// include files
#include <array>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

namespace math
{
/** real type */
using real = double;

/** array of reals */
template<std::size_t n> using Array = std::array<real, n>;

/** matrix of reals, stored as n rows of m columns (indexed a[row][column]) */
template<std::size_t n, std::size_t m> using Matrix = std::array< std::array<real, m>, n>;

/** array of indices */
template<std::size_t n> using ArrayIndex = std::array<int, n>;

/** Decompose a square matrix in place using LU decomposition with partial pivoting.

Uses Crout's method, choosing each pivot by its magnitude relative to the
largest magnitude originally in its row (implicit scaling). On return @p a
holds both factors of the row-permuted input: the strictly lower triangle
holds L (its unit diagonal is not stored) and the upper triangle, including
the diagonal, holds U.

An exactly zero pivot is replaced by a tiny value (1e-20) so that a singular
matrix with no all-zero row still produces a (badly conditioned) result.

@param a matrix to be decomposed in place
@param index output row permutation vector: at step j, row j was interchanged
	with row index[j] (index[j] == j means no interchange)
@param d output row permutation parity: +1 for an even number of row
	interchanges, -1 for an odd number
@throws std::domain_error if any row of @p a is entirely zero
*/
template<std::size_t N> void luDecompose(Matrix<N, N> &a, ArrayIndex<N> &index, real &d)
{
	// record the reciprocal of each row's largest magnitude for implicit scaling
	Array<N> vv;
	for (std::size_t i = 0; i < N; ++i)
	{
		real scale = 0.0;
		for (std::size_t j = 0; j < N; ++j)
			if (std::abs(a[i][j]) > scale)
				scale = std::abs(a[i][j]);
		if (scale == 0.0)
			throw std::domain_error("singularity");
		vv[i] = 1/scale;
	}

	// for each column using Crout's method
	d = 1.0;
	for (std::size_t j = 0; j < N; ++j)
	{
		// compute the U entries above the diagonal in column j
		for (std::size_t i = 0; i < j; ++i)
			for (std::size_t k = 0; k < i; ++k)
				a[i][j] -= a[i][k]*a[k][j];

		// compute the remaining entries of column j and find the largest scaled pivot
		std::size_t iMaxPivot = j;
		real maxPivot = 0.0;
		for (std::size_t i = j; i < N; ++i)
		{
			for (std::size_t k = 0; k < j; ++k)
				a[i][j] -= a[i][k]*a[k][j];

			const real pivot = vv[i]*std::abs(a[i][j]);
			if (pivot > maxPivot)
			{
				maxPivot = pivot;
				iMaxPivot = i;
			}
		}

		// move the chosen pivot row into place; the scale factor of the displaced
		// row j must follow it (vv[j] itself is not read again)
		if (iMaxPivot != j)
		{
			a[iMaxPivot].swap(a[j]);
			d = -d;
			vv[iMaxPivot] = vv[j];
		}

		// store permutation
		index[j] = static_cast<int>(iMaxPivot);

		// substitute a tiny pivot for an exactly zero one so the division below is defined
		const real tiny = 1.0e-20;
		if (a[j][j] == 0.0)
			a[j][j] = tiny;

		// divide the sub-diagonal entries of column j by the pivot to form L
		for (std::size_t i = j + 1; i < N; ++i)
			a[i][j] /= a[j][j];
	}
}

/** Solve A x = b in place, given the LU decomposition of A from luDecompose.

@param a LU decomposed matrix, as produced by luDecompose
@param index row permutation vector, as produced by luDecompose
@param b on input the right-hand side vector; on output the solution vector
*/
template<std::size_t N> void luBackSubstitute(const Matrix<N, N> &a, const ArrayIndex<N> &index, Array<N> &b)
{
	// forward substitution with the unit-lower-triangular factor L, applying the
	// row interchanges as we go; leading zero entries of b are skipped by tracking
	// the position of the first nonzero entry (N means none found yet)
	std::size_t first = N;
	for (std::size_t i = 0; i < N; ++i)
	{
		// unscramble permutation
		std::swap(b[static_cast<std::size_t>(index[i])], b[i]);

		if (first < N)
			for (std::size_t j = first; j < i; ++j)
				b[i] -= a[i][j]*b[j];
		else if (b[i] != 0.0)
			first = i;
	}

	// back substitution with the upper-triangular factor U; counting down with
	// "i-- > 0" avoids unsigned underflow when N == 0
	for (std::size_t i = N; i-- > 0; )
	{
		for (std::size_t j = i + 1; j < N; ++j)
			b[i] -= a[i][j]*b[j];
		b[i] /= a[i][i];
	}
}

/** Decompose a positive definite matrix in place using the Cholesky method.

Only the lower triangle (including the diagonal) of @p a is read; the matrix
is assumed symmetric. On return the lower triangle holds the factor L with
A = L L^T, and the strictly upper triangle is set to zero.

@param a matrix to be decomposed in place
@throws std::domain_error if a non-positive (or NaN) diagonal value arises,
	i.e. the matrix is not positive definite
*/
template<std::size_t N> void choleskyDecompose(Matrix<N, N> &a)
{
	// for each row
	for (std::size_t i = 0; i < N; ++i)
	{
		// diagonal entry: L[i][i] = sqrt(A[i][i] - sum_k L[i][k]^2)
		for (std::size_t k = 0; k < i; ++k)
			a[i][i] -= a[i][k]*a[i][k];
		// the negated comparison also rejects NaN, which "<= 0" would let through
		if (!(a[i][i] > 0.0))
			throw std::domain_error("not positive definite");
		a[i][i] = std::sqrt(a[i][i]);

		// entries below the diagonal in column i; clear the mirrored upper entry
		const real inverse = 1/a[i][i];
		for (std::size_t j = i + 1; j < N; ++j)
		{
			for (std::size_t k = 0; k < i; ++k)
				a[j][i] -= a[i][k]*a[j][k];
			a[j][i] *= inverse;
			a[i][j] = 0.0;
		}
	}
}
}
